{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pretraining using SlimPajama\n",
    "\n",
    "This notebook shows how to pretrain a model on the data generated by the [data pipeline notebook](./data_pipeline.ipynb). All we need to do is define a data module based on the generated data and use it in place of the mock data module that the [NeMo LLM recipes](../../../nemo/collections/llm/recipes/__init__.py) provide by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nemo_run as run\n",
    "import lightning.pytorch as pl\n",
    "from nemo.collections import llm\n",
    "from nemo.collections.common.tokenizers import SentencePieceTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the data module\n",
    "To define the data module, we can use `llm.PreTrainingDataModule` and pass in the data paths and the tokenizer. If you don't have either of these yet, refer to the [data pipeline notebook](./data_pipeline.ipynb) to generate them. The data module also supports other parameters, such as `split`, `num_workers`, and `index_mapping_dir`; see its definition for the full list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def slimpajama(\n",
    "    gbs: int = 256,\n",
    "    mbs: int = 4,\n",
    "    seq_length: int = 8192,\n",
    ") -> run.Config[pl.LightningDataModule]:\n",
    "\n",
    "    return run.Config(\n",
    "        llm.PreTrainingDataModule,\n",
    "        # Dataset prefix: the preprocessed file path without the .bin/.idx extension\n",
    "        paths=[\"/data/slimpajama_megatron/concatenated_chunk1.jsonl_text_document\"],\n",
    "        seq_length=seq_length,\n",
    "        global_batch_size=gbs,\n",
    "        micro_batch_size=mbs,\n",
    "        tokenizer=run.Config(SentencePieceTokenizer, model_path=\"/data/tokenizer/tokenizer.model\"),\n",
    "        # Relative weights of the train, validation, and test splits\n",
    "        split=\"99,8,2\",\n",
    "        num_workers=2,\n",
    "        # Directory where the dataset index mappings are cached\n",
    "        index_mapping_dir=\"/data/index_mapping\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure the recipe and launch pretraining\n",
    "Once the data module is defined, you can use an existing recipe and replace the data module as shown below.\n",
    "To learn more about the recipes, refer to the [quickstart](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/quickstart.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def configure_recipe(nodes: int = 1, gpus_per_node: int = 1):\n",
    "    recipe = llm.llama3_8b.pretrain_recipe(\n",
    "        dir=\"/checkpoints/llama-new\",  # Path to store checkpoints\n",
    "        name=\"llama_pretraining\",\n",
    "        num_nodes=nodes,\n",
    "        num_gpus_per_node=gpus_per_node,\n",
    "    )\n",
    "\n",
    "    # Use a tiny model and a short run to keep this example lightweight\n",
    "    recipe.model.config.num_layers = 1\n",
    "    recipe.model.config.hidden_size = 128\n",
    "    recipe.trainer.max_steps = 30\n",
    "\n",
    "    # Replace the default mock data module with the SlimPajama data module\n",
    "    recipe.data = slimpajama(\n",
    "        gbs=32,\n",
    "        mbs=1,\n",
    "    )\n",
    "    recipe.trainer.val_check_interval = 20\n",
    "    recipe.trainer.strategy.context_parallel_size = 1\n",
    "    recipe.log.ckpt.save_optim_on_train_end = True\n",
    "    return recipe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def local_executor_torchrun(nodes: int = 1, devices: int = 1) -> run.LocalExecutor:\n",
    "    # Environment variables for the job are configured here\n",
    "    env_vars = {\n",
    "        \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n",
    "        \"NEMO_ENV_VARNAME_TESTING\": \"1\",\n",
    "        # Expose only GPU 0; adjust this when running on more than one GPU\n",
    "        \"CUDA_VISIBLE_DEVICES\": \"0\",\n",
    "    }\n",
    "\n",
    "    executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n",
    "    return executor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_pretraining():\n",
    "    recipe = configure_recipe()\n",
    "    executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)\n",
    "\n",
    "    run.run(recipe, executor=executor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run pretraining\n",
    "Now call the `run_pretraining` function to start pretraining on your local machine with torchrun."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run_pretraining()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
